// This file is licensed under the Elastic License 2.0. Copyright 2021 StarRocks Limited.

#include "aggregate_base_operator.h"

#include "exprs/anyval_util.h"
#include "gutil/strings/substitute.h"

namespace starrocks::pipeline {

AggregateBaseOperator::AggregateBaseOperator(int32_t id, std::string name, int32_t plan_node_id, const TPlanNode& tnode)
        : Operator(id, name, plan_node_id),
          _tnode(tnode),
          _needs_finalize(tnode.agg_node.need_finalize),
          _streaming_preaggregation_mode(tnode.agg_node.streaming_preaggregation_mode),
          _intermediate_tuple_id(tnode.agg_node.intermediate_tuple_id),
          _intermediate_tuple_desc(nullptr),
          _output_tuple_id(tnode.agg_node.output_tuple_id),
          _output_tuple_desc(nullptr) {}

Status AggregateBaseOperator::prepare(RuntimeState* state) {
    RETURN_IF_ERROR(Operator::prepare(state));

    // TODO(hcf) which obj pool should I use
    _pool = state->obj_pool();
    _rows_returned_counter = ADD_COUNTER(get_runtime_profile(), "RowsReturned", TUnit::UNIT);

    // TODO(hcf) the following is copied from AggregateBaseNode::init
    RETURN_IF_ERROR(Expr::create_expr_trees(_pool, _tnode.agg_node.grouping_exprs, &_group_by_expr_ctxs));
    // add profile attributes
    if (_tnode.agg_node.__isset.sql_grouping_keys) {
        get_runtime_profile()->add_info_string("GroupingKeys", _tnode.agg_node.sql_grouping_keys);
    }
    if (_tnode.agg_node.__isset.sql_aggregate_functions) {
        get_runtime_profile()->add_info_string("AggregateFunctions", _tnode.agg_node.sql_aggregate_functions);
    }

    bool has_outer_join_child = _tnode.agg_node.__isset.has_outer_join_child && _tnode.agg_node.has_outer_join_child;
    VLOG_ROW << "has_outer_join_child " << has_outer_join_child;

    size_t group_by_size = _group_by_expr_ctxs.size();
    _group_by_columns.resize(group_by_size);
    _group_by_types.resize(group_by_size);
    for (size_t i = 0; i < group_by_size; ++i) {
        const TExprNode& expr = _tnode.agg_node.grouping_exprs[i].nodes[0];
        _group_by_types[i].result_type = TypeDescriptor::from_thrift(expr.type);
        _group_by_types[i].is_nullable = expr.is_nullable || has_outer_join_child;
        _has_nullable_key = _has_nullable_key || _group_by_types[i].is_nullable;
        VLOG_ROW << "group by column " << i << " result_type " << _group_by_types[i].result_type << " is_nullable "
                 << expr.is_nullable;
    }
    VLOG_ROW << "has_nullable_key " << _has_nullable_key;

    _tmp_agg_states.resize(config::vector_chunk_size);

    size_t agg_size = _tnode.agg_node.aggregate_functions.size();
    _agg_fn_ctxs.resize(agg_size);
    _agg_functions.resize(agg_size);
    _agg_expr_ctxs.resize(agg_size);
    _agg_intput_columns.resize(agg_size);
    _agg_input_raw_columns.resize(agg_size);
    _agg_fn_types.resize(agg_size);
    _agg_states_offsets.resize(agg_size);
    _is_merge_funcs.resize(agg_size);

    for (size_t i = 0; i < agg_size; ++i) {
        const TExpr& desc = _tnode.agg_node.aggregate_functions[i];
        const TFunction& fn = desc.nodes[0].fn;
        _is_merge_funcs[i] = desc.nodes[0].agg_expr.is_merge_agg;
        VLOG_ROW << fn.name.function_name << " is arg nullable " << desc.nodes[0].has_nullable_child;
        VLOG_ROW << fn.name.function_name << " is result nullable " << desc.nodes[0].is_nullable;
        if (fn.name.function_name == "count") {
            bool is_input_nullable =
                    !fn.arg_types.empty() && (has_outer_join_child || desc.nodes[0].has_nullable_child);
            _agg_functions[i] =
                    vectorized::get_aggregate_function("count", TYPE_BIGINT, TYPE_BIGINT, is_input_nullable);
            std::vector<FunctionContext::TypeDesc> arg_typedescs;
            _agg_fn_types[i] = {TypeDescriptor(TYPE_BIGINT), TypeDescriptor(TYPE_BIGINT), arg_typedescs, false, false};
            // count(*) has no input column; we manually resize the input to 1 so that
            // count(*) is processed like any other agg function.
            _agg_intput_columns[i].resize(1);
        } else {
            TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type);
            TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type);

            // collect arg_typedescs for aggregate function.
            std::vector<FunctionContext::TypeDesc> arg_typedescs;
            for (auto& type : fn.arg_types) {
                arg_typedescs.push_back(AnyValUtil::column_type_to_type_desc(TypeDescriptor::from_thrift(type)));
            }

            TypeDescriptor arg_type = TypeDescriptor::from_thrift(fn.arg_types[0]);
            // intersect_count takes more than one argument: its first argument is
            // always a Bitmap, so we use its second argument's type as the input type.
            if (fn.name.function_name == "intersect_count") {
                arg_type = TypeDescriptor::from_thrift(fn.arg_types[1]);
            }

            bool is_input_nullable = has_outer_join_child || desc.nodes[0].has_nullable_child;
            auto* func = vectorized::get_aggregate_function(fn.name.function_name, arg_type.type, return_type.type,
                                                            is_input_nullable);
            if (func == nullptr) {
                return Status::InternalError(
                        strings::Substitute("Invalid agg function plan: $0", fn.name.function_name));
            }
            VLOG_ROW << "get agg function " << func->get_name() << " serde_type " << serde_type << " return_type "
                     << return_type;
            _agg_functions[i] = func;
            _agg_fn_types[i] = {return_type, serde_type, arg_typedescs, is_input_nullable, desc.nodes[0].is_nullable};
        }

        int node_idx = 0;
        for (int j = 0; j < desc.nodes[0].num_children; ++j) {
            ++node_idx;
            Expr* expr = nullptr;
            ExprContext* ctx = nullptr;
            RETURN_IF_ERROR(Expr::create_tree_from_thrift(_pool, desc.nodes, nullptr, &node_idx, &expr, &ctx));
            _agg_expr_ctxs[i].emplace_back(ctx);
        }
        _agg_intput_columns[i].resize(desc.nodes[0].num_children);
        _agg_input_raw_columns[i].resize(desc.nodes[0].num_children);
    }

    // compute agg state total size and offsets
    for (size_t i = 0; i < _agg_fn_ctxs.size(); ++i) {
        _agg_states_offsets[i] = _agg_states_total_size;
        _agg_states_total_size += _agg_functions[i]->size();
        _max_agg_state_align_size = std::max(_max_agg_state_align_size, _agg_functions[i]->alignof_size());

        // If this is not the last aggregate state, pad it so that the next aggregate state is aligned.
        if (i + 1 < _agg_fn_ctxs.size()) {
            size_t next_state_align_size = _agg_functions[i + 1]->alignof_size();
            // Extend total_size to the next alignment requirement:
            // add padding by rounding '_agg_states_total_size' up to a multiple of next_state_align_size.
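            // e.g. _agg_states_total_size = 12 with next_state_align_size = 8 is padded to 16.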
            _agg_states_total_size = (_agg_states_total_size + next_state_align_size - 1) / next_state_align_size *
                                     next_state_align_size;
        }
    }

    _is_only_group_by_columns = _agg_expr_ctxs.empty() && !_group_by_expr_ctxs.empty();

    // TODO(hcf) the following is copied from AggregateBaseNode::prepare
    _get_results_timer = ADD_TIMER(get_runtime_profile(), "GetResultsTime");
    _iter_timer = ADD_TIMER(get_runtime_profile(), "ResultIteratorTime");
    _agg_append_timer = ADD_TIMER(get_runtime_profile(), "ResultAggAppendTime");
    _group_by_append_timer = ADD_TIMER(get_runtime_profile(), "ResultGroupByAppendTime");
    // TODO: split agg_compute_timer into construct_ht_time + agg_func_compute_time
    _agg_compute_timer = ADD_TIMER(get_runtime_profile(), "AggComputeTime");
    _streaming_timer = ADD_TIMER(get_runtime_profile(), "StreamingTime");
    _expr_compute_timer = ADD_TIMER(get_runtime_profile(), "ExprComputeTime");
    _expr_release_timer = ADD_TIMER(get_runtime_profile(), "ExprReleaseTime");

    _input_row_count = ADD_COUNTER(get_runtime_profile(), "InputRowCount", TUnit::UNIT);
    _hash_table_size = ADD_COUNTER(get_runtime_profile(), "HashTableSize", TUnit::UNIT);
    _pass_through_row_count = ADD_COUNTER(get_runtime_profile(), "PassThroughRowCount", TUnit::UNIT);

    SCOPED_TIMER(get_runtime_profile()->total_time_counter());

    _intermediate_tuple_desc = state->desc_tbl().get_tuple_descriptor(_intermediate_tuple_id);
    _output_tuple_desc = state->desc_tbl().get_tuple_descriptor(_output_tuple_id);
    DCHECK_EQ(_intermediate_tuple_desc->slots().size(), _output_tuple_desc->slots().size());

    // TODO(hcf) forcibly commented out for now
    // RETURN_IF_ERROR(Expr::prepare(_group_by_expr_ctxs, state, child(0)->row_desc(), expr_mem_tracker()));

    // for (const auto& ctx : _agg_expr_ctxs) {
    //     RETURN_IF_ERROR(Expr::prepare(ctx, state, child(0)->row_desc(), expr_mem_tracker()));
    // }

    _mem_pool = std::make_unique<MemPool>(get_memtracker());

    // Initialize the FunctionContext of every aggregate function
    for (size_t i = 0; i < _agg_fn_ctxs.size(); ++i) {
        _agg_fn_ctxs[i] = FunctionContextImpl::create_context(
                state, _mem_pool.get(), AnyValUtil::column_type_to_type_desc(_agg_fn_types[i].result_type),
                _agg_fn_types[i].arg_typedescs, 0, false);
        state->obj_pool()->add(_agg_fn_ctxs[i]);
    }

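    // Without group-by exprs there is exactly one global agg state: allocate it once
    // up front and create each function's state at its precomputed offset.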
    if (_group_by_expr_ctxs.empty()) {
        _single_agg_state = _mem_pool->allocate_aligned(_agg_states_total_size, _max_agg_state_align_size);
        for (size_t i = 0; i < _agg_functions.size(); i++) {
            _agg_functions[i]->create(_single_agg_state + _agg_states_offsets[i]);
        }
        if (_agg_expr_ctxs.empty()) {
            return Status::InternalError("Invalid agg query plan");
        }
    }

    // For SQL like: select distinct id from table; or: select id from table group by id;
    // we don't need to allocate memory for agg states.
    if (_is_only_group_by_columns) {
        _init_agg_hash_variant(_hash_set_variant);
    } else {
        _init_agg_hash_variant(_hash_map_variant);
    }

    return Status::OK();
}

void AggregateBaseOperator::_evaluate_const_columns(int i) {
    // used for const columns.
    std::vector<ColumnPtr> const_columns;
    const_columns.reserve(_agg_expr_ctxs[i].size());
    for (size_t j = 0; j < _agg_expr_ctxs[i].size(); ++j) {
        const_columns.emplace_back(_agg_expr_ctxs[i][j]->root()->evaluate_const(_agg_expr_ctxs[i][j]));
    }
    _agg_fn_ctxs[i]->impl()->set_constant_columns(const_columns);
}

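// Choose a specialized hash table variant from the aggregation phase and the shape of
// the group-by keys: a single fixed-size or string key gets a dedicated (possibly
// nullable) variant, while zero or multiple keys fall back to serialized slice keys.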
template <typename HashVariantType>
void AggregateBaseOperator::_init_agg_hash_variant(HashVariantType& hash_variant) {
    auto type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_slice
                                                      : HashVariantType::Type::phase2_slice;
    if (_has_nullable_key) {
        switch (_group_by_expr_ctxs.size()) {
        case 0:
            break;
        case 1: {
            auto group_by_expr = _group_by_expr_ctxs[0];
            switch (group_by_expr->root()->type().type) {
            case TYPE_TINYINT: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_null_int8
                                                             : HashVariantType::Type::phase2_null_int8;
                break;
            }
            case TYPE_SMALLINT: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_null_int16
                                                             : HashVariantType::Type::phase2_null_int16;
                break;
            }
            case TYPE_INT: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_null_int32
                                                             : HashVariantType::Type::phase2_null_int32;
                break;
            }
            case TYPE_BIGINT: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_null_int64
                                                             : HashVariantType::Type::phase2_null_int64;
                break;
            }
            case TYPE_DATE: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_null_date
                                                             : HashVariantType::Type::phase2_null_date;
                break;
            }
            case TYPE_DATETIME: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_null_timestamp
                                                             : HashVariantType::Type::phase2_null_timestamp;
                break;
            }
            case TYPE_CHAR:
            case TYPE_VARCHAR: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_null_string
                                                             : HashVariantType::Type::phase2_null_string;
                break;
            }
            default: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_slice
                                                             : HashVariantType::Type::phase2_slice;
                break;
            }
            }
            break;
        }
        default: {
            type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_slice
                                                         : HashVariantType::Type::phase2_slice;
            break;
        }
        }
    } else {
        switch (_group_by_expr_ctxs.size()) {
        case 0:
            break;
        case 1: {
            auto group_by_expr = _group_by_expr_ctxs[0];
            switch (group_by_expr->root()->type().type) {
            case TYPE_TINYINT: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_int8
                                                             : HashVariantType::Type::phase2_int8;
                break;
            }
            case TYPE_SMALLINT: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_int16
                                                             : HashVariantType::Type::phase2_int16;
                break;
            }
            case TYPE_INT: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_int32
                                                             : HashVariantType::Type::phase2_int32;
                break;
            }
            case TYPE_BIGINT: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_int64
                                                             : HashVariantType::Type::phase2_int64;
                break;
            }
            case TYPE_DATE: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_date
                                                             : HashVariantType::Type::phase2_date;
                break;
            }
            case TYPE_DATETIME: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_timestamp
                                                             : HashVariantType::Type::phase2_timestamp;
                break;
            }
            case TYPE_CHAR:
            case TYPE_VARCHAR: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_string
                                                             : HashVariantType::Type::phase2_string;
                break;
            }
            default: {
                type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_slice
                                                             : HashVariantType::Type::phase2_slice;
                break;
            }
            }
            break;
        }
        default: {
            type = _aggr_phase == vectorized::AggrPhase1 ? HashVariantType::Type::phase1_slice
                                                         : HashVariantType::Type::phase2_slice;
            break;
        }
        }
    }
    VLOG_ROW << "hash type is "
             << static_cast<typename std::underlying_type<typename HashVariantType::Type>::type>(type);
    hash_variant.init(type);
}

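// With no group-by keys, every row of the chunk is accumulated into the single global
// agg state; merge-type functions consume a pre-aggregated intermediate column instead
// of the raw input columns.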
void AggregateBaseOperator::_compute_single_agg_state(size_t chunk_size) {
    for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
        if (!_is_merge_funcs[i]) {
            _agg_functions[i]->update_batch_single_state(_agg_fn_ctxs[i], chunk_size, _agg_input_raw_columns[i].data(),
                                                         _single_agg_state + _agg_states_offsets[i]);
        } else {
            DCHECK_EQ(_agg_intput_columns[i].size(), 1);
            _agg_functions[i]->merge_batch_single_state(_agg_fn_ctxs[i], chunk_size, _agg_intput_columns[i][0].get(),
                                                        _single_agg_state + _agg_states_offsets[i]);
        }
    }
}

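// With group-by keys, each row is accumulated into the agg state of its group:
// _tmp_agg_states[row] points at the state the hash table assigned to that row.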
void AggregateBaseOperator::_compute_batch_agg_states(size_t chunk_size) {
    for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
        if (!_is_merge_funcs[i]) {
            _agg_functions[i]->update_batch(_agg_fn_ctxs[i], chunk_size, _agg_states_offsets[i],
                                            _agg_input_raw_columns[i].data(), _tmp_agg_states.data());
        } else {
            DCHECK_EQ(_agg_intput_columns[i].size(), 1);
            _agg_functions[i]->merge_batch(_agg_fn_ctxs[i], _agg_intput_columns[i][0]->size(), _agg_states_offsets[i],
                                           _agg_intput_columns[i][0].get(), _tmp_agg_states.data());
        }
    }
}

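// Selective variant used by streaming preaggregation: rows are filtered by `selection`
// so that pass-through rows do not update any agg state.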
void AggregateBaseOperator::_compute_batch_agg_states(size_t chunk_size, const std::vector<uint8_t>& selection) {
    for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
        _agg_functions[i]->update_batch_selectively(_agg_fn_ctxs[i], chunk_size, _agg_states_offsets[i],
                                                    _agg_input_raw_columns[i].data(), _tmp_agg_states.data(),
                                                    selection);
    }
}

void AggregateBaseOperator::_evaluate_group_by_exprs(vectorized::Chunk* chunk) {
    SCOPED_TIMER(_expr_release_timer);
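    // Clear the column references held from the previous chunk so that their memory
    // can be released; the actual evaluation happens in _evaluate_agg_fn_exprs.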
    for (size_t i = 0; i < _group_by_expr_ctxs.size(); i++) {
        _group_by_columns[i] = nullptr;
    }

    for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
        for (size_t j = 0; j < _agg_expr_ctxs[i].size(); j++) {
            _agg_intput_columns[i][j] = nullptr;
            _agg_input_raw_columns[i][j] = nullptr;
        }
    }
}
void AggregateBaseOperator::_evaluate_agg_fn_exprs(vectorized::Chunk* chunk) {
    SCOPED_TIMER(_expr_compute_timer);
    // Compute group by columns
    for (size_t i = 0; i < _group_by_expr_ctxs.size(); i++) {
        _group_by_columns[i] = _group_by_expr_ctxs[i]->evaluate(chunk);
        DCHECK(_group_by_columns[i] != nullptr);
        if (_group_by_columns[i]->is_constant()) {
            // If a group-by column is constant, disable streaming aggregation,
            // because we don't want to send a const column to the exchange node.
            _streaming_preaggregation_mode = TStreamingPreaggregationMode::FORCE_PREAGGREGATION;
            // Every hash table can handle an only-null column, and we don't know the
            // real data type of an only-null column, so we don't unpack it.
            if (!_group_by_columns[i]->only_null()) {
                vectorized::ConstColumn* const_column =
                        static_cast<vectorized::ConstColumn*>(_group_by_columns[i].get());
                const_column->data_column()->assign(chunk->num_rows(), 0);
                _group_by_columns[i] = const_column->data_column();
            }
        }
        // Scalar function evaluation may return a non-nullable column for a nullable
        // expression when every row in the chunk happens to be non-null.
        if (_group_by_types[i].is_nullable && !_group_by_columns[i]->is_nullable()) {
            // TODO: optimize the memory usage
            _group_by_columns[i] = vectorized::NullableColumn::create(
                    _group_by_columns[i], vectorized::NullColumn::create(_group_by_columns[i]->size(), 0));
        }
    }

    // Compute agg function columns
    for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
        for (size_t j = 0; j < _agg_expr_ctxs[i].size(); j++) {
            // For simplicity, and to avoid changing the overall processing flow,
            // we handle const columns as normal data columns.
            // TODO(kks): improve const column aggregate later
            if (j == 0) {
                _agg_intput_columns[i][j] = vectorized::ColumnHelper::unpack_and_duplicate_const_column(
                        chunk->num_rows(), _agg_expr_ctxs[i][j]->evaluate(chunk));
            } else {
                _agg_intput_columns[i][j] = _agg_expr_ctxs[i][j]->evaluate(chunk);
            }
            _agg_input_raw_columns[i][j] = _agg_intput_columns[i][j].get();
        }
    }
}

// When finalize is needed, create columns by result type;
// otherwise, create columns by serde type.
vectorized::Columns AggregateBaseOperator::_create_agg_result_columns() {
    vectorized::Columns agg_result_columns(_agg_fn_types.size());
    if (_needs_finalize) {
        for (size_t i = 0; i < _agg_fn_types.size(); ++i) {
            // For functions that never return null, such as count, count distinct and
            // bitmap_union_int, we need to create a non-nullable column.
            agg_result_columns[i] = vectorized::ColumnHelper::create_column(
                    _agg_fn_types[i].result_type, _agg_fn_types[i].has_nullable_child && _agg_fn_types[i].is_nullable);
            agg_result_columns[i]->reserve(config::vector_chunk_size);
        }
    } else {
        for (size_t i = 0; i < _agg_fn_types.size(); ++i) {
            agg_result_columns[i] = vectorized::ColumnHelper::create_column(_agg_fn_types[i].serde_type,
                                                                            _agg_fn_types[i].has_nullable_child);
            agg_result_columns[i]->reserve(config::vector_chunk_size);
        }
    }
    return agg_result_columns;
}

vectorized::Columns AggregateBaseOperator::_create_group_by_columns() {
    vectorized::Columns group_by_columns(_group_by_types.size());
    for (size_t i = 0; i < _group_by_types.size(); ++i) {
        group_by_columns[i] =
                vectorized::ColumnHelper::create_column(_group_by_types[i].result_type, _group_by_types[i].is_nullable);
        group_by_columns[i]->reserve(config::vector_chunk_size);
    }
    return group_by_columns;
}

void AggregateBaseOperator::_convert_to_chunk_no_groupby(vectorized::ChunkPtr* chunk) {
    // TODO(kks): we should account for the memory allocated here
    vectorized::Columns agg_result_column = _create_agg_result_columns();

    if (_needs_finalize) {
        _finalize_to_chunk(_single_agg_state, agg_result_column);
    } else {
        _serialize_to_chunk(_single_agg_state, agg_result_column);
    }

    // If the agg function column is non-nullable but the table is empty,
    // an aggregate over zero rows, e.g. sum(), should be null rather than 0.
    if (UNLIKELY(_num_input_rows == 0 && _group_by_expr_ctxs.empty() && _needs_finalize)) {
        for (size_t i = 0; i < _agg_fn_types.size(); i++) {
            if (_agg_fn_types[i].is_nullable) {
                agg_result_column[i] = vectorized::ColumnHelper::create_column(_agg_fn_types[i].result_type, true);
                agg_result_column[i]->append_default();
            }
        }
    }

    TupleDescriptor* tuple_desc = nullptr;
    if (_needs_finalize) {
        tuple_desc = _output_tuple_desc;
    } else {
        tuple_desc = _intermediate_tuple_desc;
    }

    ChunkPtr result_chunk = std::make_shared<vectorized::Chunk>();
    for (size_t i = 0; i < agg_result_column.size(); i++) {
        result_chunk->append_column(std::move(agg_result_column[i]), tuple_desc->slots()[i]->id());
    }
    ++_num_rows_returned;
    *chunk = std::move(result_chunk);
    _is_ht_done = true;
}

void AggregateBaseOperator::_serialize_to_chunk(vectorized::ConstAggDataPtr state,
                                                const vectorized::Columns& agg_result_columns) {
    for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
        _agg_functions[i]->serialize_to_column(_agg_fn_ctxs[i], state + _agg_states_offsets[i],
                                               agg_result_columns[i].get());
    }
}

void AggregateBaseOperator::_finalize_to_chunk(vectorized::ConstAggDataPtr state,
                                               const vectorized::Columns& agg_result_columns) {
    for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
        _agg_functions[i]->finalize_to_column(_agg_fn_ctxs[i], state + _agg_states_offsets[i],
                                              agg_result_columns[i].get());
    }
}

Status AggregateBaseOperator::close(RuntimeState* state) {
    for (auto ctx : _agg_fn_ctxs) {
        if (ctx != nullptr && ctx->impl()) {
            ctx->impl()->close();
        }
    }

    // A null _mem_pool means the prepare phase failed.
    if (_mem_pool != nullptr) {
        // Note: we must destroy the agg state objects before _mem_pool->free_all().
        if (_single_agg_state != nullptr) {
            for (size_t i = 0; i < _agg_functions.size(); i++) {
                _agg_functions[i]->destroy(_single_agg_state + _agg_states_offsets[i]);
            }
        } else if (!_is_only_group_by_columns) {
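            // Dispatch on the active hash map variant to destroy every per-group agg
            // state; the empty `if (false)` seeds the macro-generated else-if chain.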
            if (false) {
            }
#define HASH_MAP_METHOD(NAME)                                                  \
    else if (_hash_map_variant.type == vectorized::HashMapVariant::Type::NAME) \
            _release_agg_memory<decltype(_hash_map_variant.NAME)::element_type>(*_hash_map_variant.NAME);
            APPLY_FOR_VARIANT_ALL(HASH_MAP_METHOD)
#undef HASH_MAP_METHOD
        }

        _mem_pool->free_all();
    }

    get_memtracker()->release(_last_agg_func_memory_usage);
    get_memtracker()->release(_last_ht_memory_usage);

    Expr::close(_group_by_expr_ctxs, state);
    for (const auto& i : _agg_expr_ctxs) {
        Expr::close(i, state);
    }

    return Operator::close(state);
}

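// Emit the current chunk in intermediate (serialized) form, bypassing the hash table:
// group-by columns are forwarded as-is and agg inputs are converted to serialize format.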
void AggregateBaseOperator::_output_chunk_by_streaming(vectorized::ChunkPtr* chunk) {
    ChunkPtr result_chunk = std::make_shared<vectorized::Chunk>();
    for (size_t i = 0; i < _group_by_columns.size(); i++) {
        result_chunk->append_column(_group_by_columns[i], _intermediate_tuple_desc->slots()[i]->id());
    }

    if (!_agg_fn_ctxs.empty()) {
        vectorized::Columns agg_result_column = _create_agg_result_columns();
        for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
            size_t id = _group_by_columns.size() + i;
            _agg_functions[i]->convert_to_serialize_format(_agg_intput_columns[i], result_chunk->num_rows(),
                                                           &agg_result_column[i]);
            result_chunk->append_column(std::move(agg_result_column[i]), _intermediate_tuple_desc->slots()[id]->id());
        }
    }
    _num_pass_through_rows += result_chunk->num_rows();
    _num_rows_returned += result_chunk->num_rows();
    *chunk = std::move(result_chunk);
    COUNTER_SET(_pass_through_row_count, _num_pass_through_rows);
}

void AggregateBaseOperator::_output_chunk_by_streaming(vectorized::ChunkPtr* chunk,
                                                       const std::vector<uint8_t>& filter) {
    // Streaming aggregation always has at least one group-by column.
    size_t chunk_size = _group_by_columns[0]->size();
    for (auto& group_by_column : _group_by_columns) {
        // Multiple group-by columns may share the same shared_ptr.
        // If the column size and the chunk size differ, the filter has already been
        // applied through a previous group-by column,
        // e.g.: select c1, cast(c1 as int) from t1 group by c1, cast(c1 as int);

        // At present this kind of problem cannot be solved completely;
        // a new design is needed to handle it thoroughly.
        if (group_by_column->size() == chunk_size) {
            group_by_column->filter(filter);
        }
    }
    for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) {
        for (auto& agg_input_column : _agg_intput_columns[i]) {
            // An agg input column and a group-by column may share the same shared_ptr.
            // If the column size and the chunk size differ, the filter has already been
            // applied through a group-by column,
            // e.g.: select c1, count(distinct c1) from t1 group by c1;

            // At present this kind of problem cannot be solved completely;
            // a new design is needed to handle it thoroughly.
            if (agg_input_column->size() == chunk_size) {
                agg_input_column->filter(filter);
            }
        }
    }
    _output_chunk_by_streaming(chunk);
}

Status AggregateBaseOperator::_check_hash_map_memory_usage(RuntimeState* state) {
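    // Amortize the cost: only run the check roughly once per batch of input rows (the
    // bitwise AND acts as a cheap modulo, assuming memory_check_batch_size is 2^n - 1).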
    if ((_num_input_rows & memory_check_batch_size) < config::vector_chunk_size) {
        int64_t delta_memory_usage = static_cast<int64_t>(_hash_map_variant.memory_usage()) - _last_ht_memory_usage;
        get_memtracker()->consume(delta_memory_usage);
        _last_ht_memory_usage = _hash_map_variant.memory_usage();

        int64_t agg_func_memory_usage = 0;
        for (auto& _agg_fn_ctx : _agg_fn_ctxs) {
            agg_func_memory_usage += _agg_fn_ctx->impl()->mem_usage();
        }
        get_memtracker()->consume(agg_func_memory_usage - _last_agg_func_memory_usage);
        _last_agg_func_memory_usage = agg_func_memory_usage;

        RETURN_IF_ERROR(state->check_query_state("Aggregation Node"));
    }
    return Status::OK();
}

Status AggregateBaseOperator::_check_hash_set_memory_usage(RuntimeState* state) {
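    // Same periodic check as _check_hash_map_memory_usage, but only the hash set's
    // memory is tracked, since group-by-only aggregation keeps no agg states.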
    if ((_num_input_rows & memory_check_batch_size) < config::vector_chunk_size) {
        int64_t delta_memory_usage = static_cast<int64_t>(_hash_set_variant.memory_usage()) - _last_ht_memory_usage;
        get_memtracker()->consume(delta_memory_usage);
        _last_ht_memory_usage = _hash_set_variant.memory_usage();

        RETURN_IF_ERROR(state->check_query_state("Aggregation Node"));
    }
    return Status::OK();
}

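// Once the single-level slice hash map grows past two_level_memory_threshold, migrate
// its contents into a two-level map, which scales better at high key cardinality.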
void AggregateBaseOperator::_try_convert_to_two_level_map() {
    if (_last_ht_memory_usage > two_level_memory_threshold) {
        if (_hash_map_variant.type == vectorized::HashMapVariant::Type::phase1_slice) {
            _hash_map_variant.phase1_slice_two_level =
                    std::make_unique<vectorized::SerializedKeyTwoLevelAggHashMap<vectorized::PhmapSeed1>>();

            _hash_map_variant.phase1_slice_two_level->hash_map.reserve(
                    _hash_map_variant.phase1_slice->hash_map.capacity());

            _hash_map_variant.phase1_slice_two_level->hash_map.insert(_hash_map_variant.phase1_slice->hash_map.begin(),
                                                                      _hash_map_variant.phase1_slice->hash_map.end());

            _hash_map_variant.type = vectorized::HashMapVariant::Type::phase1_slice_two_level;
            _hash_map_variant.phase1_slice.reset();
        } else if (_hash_map_variant.type == vectorized::HashMapVariant::Type::phase2_slice) {
            _hash_map_variant.phase2_slice_two_level =
                    std::make_unique<vectorized::SerializedKeyTwoLevelAggHashMap<vectorized::PhmapSeed2>>();

            _hash_map_variant.phase2_slice_two_level->hash_map.reserve(
                    _hash_map_variant.phase2_slice->hash_map.capacity());

            _hash_map_variant.phase2_slice_two_level->hash_map.insert(_hash_map_variant.phase2_slice->hash_map.begin(),
                                                                      _hash_map_variant.phase2_slice->hash_map.end());

            _hash_map_variant.type = vectorized::HashMapVariant::Type::phase2_slice_two_level;
            _hash_map_variant.phase2_slice.reset();
        }
    }
}

bool AggregateBaseOperator::_should_expand_preagg_hash_tables(size_t input_chunk_size, int64_t ht_mem,
                                                              int64_t ht_rows) const {
    // Need some rows in tables to have valid statistics.
    if (ht_rows == 0) {
        return true;
    }

    // Find the appropriate reduction factor in our table for the current hash table sizes.
    int cache_level = 0;
    while (cache_level + 1 < vectorized::STREAMING_HT_MIN_REDUCTION_SIZE &&
           ht_mem >= vectorized::STREAMING_HT_MIN_REDUCTION[cache_level + 1].min_ht_mem) {
        cache_level++;
    }

    // Compare the number of rows in the hash table with the number of input rows that
    // were aggregated into it. Exclude passed through rows from this calculation since
    // they were not in hash tables.
    // TODO(hcf) remove this logic of _prev_num_rows_returned
    const int64_t input_rows = _prev_num_rows_returned - input_chunk_size;
    const int64_t aggregated_input_rows = input_rows - _num_rows_returned;

    // The row statistics may be inaccurate, which would make the reduction factor
    // below meaningless, so bail out early for non-positive counts.
    if (aggregated_input_rows <= 0) {
        return true;
    }
    double current_reduction = static_cast<double>(aggregated_input_rows) / ht_rows;
    // Extrapolate the current reduction factor (r) using the formula
    // R = 1 + (N / n) * (r - 1), where R is the reduction factor over the full input data
    // set, N is the number of input rows, excluding passed-through rows, and n is the
    // number of rows inserted or merged into the hash tables. This is a very rough
    // approximation but is good enough to be useful.
    double min_reduction = vectorized::STREAMING_HT_MIN_REDUCTION[cache_level].streaming_ht_min_reduction;
    return current_reduction > min_reduction;
}
} // namespace starrocks::pipeline